package top.zhenganwen.security.core.verifycode.filter;

import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
import org.springframework.social.connect.web.HttpSessionSessionStrategy;
import org.springframework.social.connect.web.SessionStrategy;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.context.request.ServletWebRequest;
import org.springframework.web.filter.OncePerRequestFilter;
import top.zhenganwen.security.core.properties.SecurityProperties;
import top.zhenganwen.security.core.verifycode.exception.VerifyCodeException;
import top.zhenganwen.security.core.verifycode.po.VerifyCode;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import static top.zhenganwen.security.core.verifycode.constont.VerifyCodeConstant.SMS_CODE_SESSION_KEY;

/**
 * Filter that checks the SMS verification code submitted with a request before letting the
 * request continue down the filter chain.
 * <p>
 * A request is checked when its request URI matches {@code /auth/sms} or one of the
 * comma-separated Ant-style patterns returned by
 * {@code securityProperties.getCode().getSms().getUriPatterns()}. The submitted
 * {@code smsCode} request parameter is compared with the {@link VerifyCode} stored in the
 * session under {@code SMS_CODE_SESSION_KEY}. If validation fails, the configured
 * {@link AuthenticationFailureHandler} is invoked and the filter chain is not continued.
 * <p>
 * Setup requirements:
 * <ul>
 *   <li>{@link #setSecurityProperties} must be called before {@link #afterPropertiesSet()}.</li>
 *   <li>{@link #setAuthenticationFailureHandler} must be called before the first failing request.</li>
 * </ul>
 *
 * @author zhenganwen
 * @date 2019/8/24
 */
//@Component
public class SmsCodeAuthenticationFilter extends OncePerRequestFilter implements InitializingBean {

    /** Login processing URI that always requires SMS code validation. */
    private static final String DEFAULT_SMS_LOGIN_URI = "/auth/sms";

    /** Name of the request parameter carrying the user-submitted SMS code. */
    private static final String SMS_CODE_PARAMETER = "smsCode";

    private AuthenticationFailureHandler authenticationFailureHandler;

    private final SessionStrategy sessionStrategy = new HttpSessionSessionStrategy();

    private SecurityProperties securityProperties;

    /** Ant-style URI patterns whose requests must carry a valid SMS code. */
    private final Set<String> uriPatternSet = new HashSet<>();

    /** Matches request URIs against Ant-style patterns, e.g. {@code /user/1} against {@code /user/*}. */
    private final AntPathMatcher antPathMatcher = new AntPathMatcher();

    /**
     * Builds the set of URI patterns that require SMS code validation.
     * <p>
     * Patterns are read from the comma-separated value of
     * {@code securityProperties.getCode().getSms().getUriPatterns()}. Surrounding whitespace is
     * trimmed and empty entries are ignored. {@code /auth/sms} is always included.
     *
     * @throws ServletException propagated from the superclass initialization
     * @throws IllegalStateException if {@link #setSecurityProperties} has not been called
     */
    @Override
    public void afterPropertiesSet() throws ServletException {
        super.afterPropertiesSet();
        if (securityProperties == null) {
            throw new IllegalStateException("securityProperties must be set before afterPropertiesSet()");
        }
        String uriPatterns = securityProperties.getCode().getSms().getUriPatterns();
        // StringUtils.split returns null for null input and drops empty tokens (e.g. "a,,b" or a
        // trailing comma). Trimming each token allows values like "/a, /b" with spaces after commas.
        String[] tokens = StringUtils.split(uriPatterns, ',');
        if (tokens != null) {
            for (String token : tokens) {
                String pattern = StringUtils.trimToNull(token);
                if (pattern != null) {
                    uriPatternSet.add(pattern);
                }
            }
        }
        uriPatternSet.add(DEFAULT_SMS_LOGIN_URI);
    }

    /**
     * Validates the SMS code when the request URI matches one of the configured patterns, then
     * continues the filter chain. If validation fails, the failure handler is invoked and the
     * chain is NOT continued.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException,
            IOException {
        // NOTE(review): getRequestURI() includes the context path, so the patterns only match
        // as written when the application is deployed at the root context -- confirm intended.
        if (requiresValidation(request.getRequestURI())) {
            try {
                validateVerifyCode(new ServletWebRequest(request));
            } catch (VerifyCodeException e) {
                authenticationFailureHandler.onAuthenticationFailure(request, response, e);
                return;
            }
        }
        filterChain.doFilter(request, response);
    }

    /**
     * Returns whether the given request URI matches any configured pattern.
     *
     * @param requestUri the request URI to test
     * @return {@code true} if at least one pattern in {@link #uriPatternSet} matches
     */
    private boolean requiresValidation(String requestUri) {
        for (String uriPattern : uriPatternSet) {
            if (antPathMatcher.match(uriPattern, requestUri)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compares the submitted {@code smsCode} parameter with the code stored in the session.
     * <p>
     * The stored code is removed from the session after a successful match, so it cannot be
     * replayed. It is also removed once it has expired, so a dead code does not linger.
     *
     * @param request the current request
     * @throws VerifyCodeException if the submitted code is blank, no code is stored in the session,
     *                             the stored code has expired, or the codes do not match
     */
    private void validateVerifyCode(ServletWebRequest request) {
        String smsCode = request.getParameter(SMS_CODE_PARAMETER);
        if (StringUtils.isBlank(smsCode)) {
            throw new VerifyCodeException("验证码不能为空");
        }
        VerifyCode verifyCode = (VerifyCode) sessionStrategy.getAttribute(request, SMS_CODE_SESSION_KEY);
        if (verifyCode == null) {
            throw new VerifyCodeException("验证码不存在");
        }
        if (verifyCode.isExpired()) {
            // An expired code can never succeed; drop it instead of leaving it in the session.
            sessionStrategy.removeAttribute(request, SMS_CODE_SESSION_KEY);
            throw new VerifyCodeException("验证码已过期，请刷新页面");
        }
        if (!StringUtils.equals(smsCode, verifyCode.getCode())) {
            // NOTE(review): a wrong guess leaves the stored code in place, so it can be retried
            // until it expires; consider limiting attempts.
            throw new VerifyCodeException("验证码错误");
        }
        // Single use: remove the code once it has been consumed.
        sessionStrategy.removeAttribute(request, SMS_CODE_SESSION_KEY);
    }

    /** Sets the handler invoked when SMS code validation fails. */
    public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
        this.authenticationFailureHandler = authenticationFailureHandler;
    }

    /** Sets the properties from which the protected URI patterns are read in {@link #afterPropertiesSet()}. */
    public void setSecurityProperties(SecurityProperties securityProperties) {
        this.securityProperties = securityProperties;
    }

    /** Returns the handler invoked when SMS code validation fails. */
    public AuthenticationFailureHandler getAuthenticationFailureHandler() {
        return authenticationFailureHandler;
    }

    /** Returns the configured security properties. */
    public SecurityProperties getSecurityProperties() {
        return securityProperties;
    }
}
